package ci.web.util;

import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.net.JarURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
/**
 * Helpers for loading classes by name and for discovering the classes of a package
 * on the class path (exploded directories and jar files).
 *
 * @author zhh
 */
public class ClassUtil {

    /** File-name suffix of compiled classes, both on disk and inside jars. */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Returns {@code loader}, falling back to the loader of {@code ClassUtil} when it is null.
     */
    private static ClassLoader safe(ClassLoader loader) {
        return loader == null ? ClassUtil.class.getClassLoader() : loader;
    }

    /**
     * Loads and initializes (runs static initializers of) a class through the given loader.
     *
     * @param loader    loader to use; {@code null} means the loader of {@code ClassUtil}
     * @param className fully qualified binary class name
     * @return the loaded class
     * @throws Exception declared for compatibility; lookup failures are thrown as a
     *                   {@link RuntimeException} wrapping the original exception
     */
    public static Class<?> loadClass(ClassLoader loader, String className) throws Exception {
        try {
            return Class.forName(className, true, safe(loader));
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Loads and initializes a class with {@link Class#forName(String)}, i.e. through the loader
     * of {@code ClassUtil}.
     *
     * @param className fully qualified binary class name
     * @return the loaded class
     * @throws Exception the unwrapped {@link ClassNotFoundException} when the class is missing
     */
    public static Class<?> loadClass(String className) throws Exception {
        return Class.forName(className);
    }

    /**
     * Loads a class through the given loader and instantiates it with its no-arg constructor.
     *
     * @param loader    loader to use; {@code null} means the loader of {@code ClassUtil}
     * @param className fully qualified binary class name
     * @return a new instance of the class
     * @throws Exception declared for compatibility; failures (class missing, no accessible
     *                   no-arg constructor, constructor threw an exception) are thrown as a
     *                   {@link RuntimeException} whose cause is the underlying problem
     */
    public static Object newInstance(ClassLoader loader, String className) throws Exception {
        Class<?> clazz = loadClass(loader, className);
        try {
            // Class#newInstance() is deprecated: it rethrows checked constructor exceptions unchecked.
            return clazz.getDeclaredConstructor().newInstance();
        } catch (InvocationTargetException e) {
            // Report the constructor's own failure, not the reflective wrapper.
            Throwable target = e.getCause() != null ? e.getCause() : e;
            if (target instanceof Error) {
                throw (Error) target;
            }
            throw new RuntimeException(target.getMessage(), target);
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Loads a class through the loader of {@code ClassUtil} and instantiates it with its
     * no-arg constructor.
     *
     * @param className fully qualified binary class name
     * @return a new instance of the class
     * @throws Exception see {@link #newInstance(ClassLoader, String)}
     */
    public static Object newInstance(String className) throws Exception {
        return newInstance(ClassUtil.class.getClassLoader(), className);
    }

    /**
     * Finds and loads every class in {@code packageName} and its sub-packages, using the loader
     * of {@code ClassUtil}.
     *
     * @param packageName dotted package name, e.g. {@code "ci.web"}
     * @return the loaded classes (each initialized)
     * @throws Exception if a resource cannot be read or a class cannot be loaded
     */
    public static Set<Class<?>> getAllClass(String packageName) throws Exception {
        return getAllClass(ClassUtil.class.getClassLoader(), packageName);
    }

    /**
     * Finds and loads every class in {@code packageName} and its sub-packages. Both exploded
     * directories ({@code file:} resources) and jars ({@code jar:} resources) are scanned;
     * resources with other protocols are ignored.
     *
     * @param loader      loader to search and load with; {@code null} means the loader of {@code ClassUtil}
     * @param packageName dotted package name, e.g. {@code "ci.web"}
     * @return the loaded classes (each initialized)
     * @throws Exception if a resource cannot be read or a class cannot be loaded
     */
    public static Set<Class<?>> getAllClass(ClassLoader loader,
            String packageName) throws Exception {
        ClassLoader effective = safe(loader);
        Set<Class<?>> set = new HashSet<>();
        // Resource names are always '/'-separated, whatever the platform's File.separatorChar is.
        String packagePath = packageName.replace('.', '/');
        Enumeration<URL> resources = effective.getResources(packagePath);
        if (resources == null) {
            return set;
        }
        while (resources.hasMoreElements()) {
            URL url = resources.nextElement();
            String protocol = url.getProtocol();
            if ("file".equals(protocol)) {
                File packageDir = toFile(url);
                set.addAll(getAllClassFromDir(effective, packageDir,
                        classpathRootPath(packageDir, packageName)));
            } else if ("jar".equalsIgnoreCase(protocol)) {
                JarURLConnection connection = (JarURLConnection) url.openConnection();
                // A cached JarFile may be shared through the URL-connection cache; request a
                // private copy so closing it below cannot affect other readers of the jar.
                connection.setUseCaches(false);
                try (JarFile jar = connection.getJarFile()) {
                    set.addAll(getAllClassFromJar(effective, jar, packageName));
                }
            }
        }
        return set;
    }

    /**
     * Converts a {@code file:} URL to a {@link File}. The URI-based conversion decodes
     * percent-escapes without turning a literal '+' into a space (as {@link URLDecoder} would);
     * URLs that are not valid URIs fall back to UTF-8 decoding of {@link URL#getFile()}.
     */
    private static File toFile(URL url) throws Exception {
        try {
            return new File(url.toURI());
        } catch (URISyntaxException | IllegalArgumentException notAUri) {
            return new File(URLDecoder.decode(url.getFile(), "utf-8"));
        }
    }

    /**
     * Returns the absolute class-path root containing {@code packageDir}, with a trailing
     * separator, by stepping up one directory per segment of {@code packageName}.
     */
    private static String classpathRootPath(File packageDir, String packageName) {
        File root = packageDir.getAbsoluteFile();
        for (String segment : packageName.split("\\.")) {
            if (!segment.isEmpty() && root.getParentFile() != null) {
                root = root.getParentFile();
            }
        }
        String path = root.getAbsolutePath();
        return path.endsWith(File.separator) ? path : path + File.separatorChar;
    }

    /**
     * Recursively loads every {@code .class} file below {@code directory}. A class name is the
     * file path relative to {@code baseDir}, minus the suffix, with separators turned into dots.
     *
     * @param loader    loader to load with; {@code null} means the loader of {@code ClassUtil}
     * @param directory directory to scan; a missing path or a non-directory yields an empty set
     * @param baseDir   absolute class-path root prefix (normally ending in a separator); files
     *                  outside it are skipped
     * @return the loaded classes (each initialized)
     * @throws Exception if a class cannot be loaded
     */
    public static Set<Class<?>> getAllClassFromDir(ClassLoader loader, File directory,
            String baseDir) throws Exception {
        Set<Class<?>> classes = new HashSet<>();
        // listFiles() returns null for a non-directory or on an I/O error.
        File[] files = directory.exists() ? directory.listFiles() : null;
        if (files == null) {
            return classes;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                classes.addAll(getAllClassFromDir(loader, file, baseDir));
                continue;
            }
            String path = file.getAbsolutePath();
            int nameEnd = path.length() - CLASS_SUFFIX.length();
            if (path.endsWith(CLASS_SUFFIX) && path.startsWith(baseDir) && nameEnd > baseDir.length()) {
                String className = path.substring(baseDir.length(), nameEnd)
                        .replace('/', '.').replace('\\', '.');
                classes.add(loadClass(loader, className));
            }
        }
        return classes;
    }

    /**
     * Loads every class in {@code jarFile} that belongs to {@code packageName} or one of its
     * sub-packages (a sibling such as {@code ci.webapp} does not match {@code ci.web}); an empty
     * package name matches every class entry.
     *
     * @param loader      loader to load with; {@code null} means the loader of {@code ClassUtil}
     * @param jarFile     jar to enumerate (not closed by this method)
     * @param packageName dotted package name
     * @return the loaded classes (each initialized)
     * @throws Exception if a class cannot be loaded
     * @throws ClassNotFoundException declared for compatibility; lookup failures are reported
     *                                as {@link RuntimeException}
     */
    public static Set<Class<?>> getAllClassFromJar(ClassLoader loader, JarFile jarFile,
            String packageName) throws Exception,
            ClassNotFoundException {
        Set<Class<?>> classes = new HashSet<>();
        String prefix = packageName.isEmpty() || packageName.endsWith(".")
                ? packageName
                : packageName + ".";
        Enumeration<JarEntry> entries = jarFile.entries();
        while (entries.hasMoreElements()) {
            JarEntry entry = entries.nextElement();
            if (entry == null) {
                continue;
            }
            String entryName = entry.getName();
            if (!entryName.endsWith(CLASS_SUFFIX)) {
                continue;
            }
            String className = entryName.substring(0, entryName.length() - CLASS_SUFFIX.length())
                    .replace('/', '.').replace('\\', '.');
            if (className.startsWith(prefix)) {
                classes.add(loadClass(loader, className));
            }
        }
        return classes;
    }

}
